{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concise Logistic Regression for Image Classification\n",
    "\n",
    "- Shows a concise implementation of logistic regression for image classification\n",
    "- Uses PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "from torchvision import datasets, models, transforms\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# use gpu if available\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download the data (uncomment if to download the data locally)\n",
    "#!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip\n",
    "#!unzip hymenoptera_data.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create data loaders\n",
    "\n",
    "data_dir = 'hymenoptera_data'\n",
    "\n",
    "# custom transformer to flatten the image tensors\n",
    "class ReshapeTransform:\n",
    "    def __init__(self, new_size):\n",
    "        self.new_size = new_size\n",
    "\n",
    "    def __call__(self, img):\n",
    "        result = torch.reshape(img, self.new_size)\n",
    "        return result\n",
    "\n",
    "# transformations used to standardize and normalize the datasets\n",
    "data_transforms = {\n",
    "    'train': transforms.Compose([\n",
    "        transforms.Resize(224),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        ReshapeTransform((-1,)) # flattens the data\n",
    "    ]),\n",
    "    'val': transforms.Compose([\n",
    "        transforms.Resize(224),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        ReshapeTransform((-1,)) # flattens the data\n",
    "    ]),\n",
    "}\n",
    "\n",
    "# load the correspoding folders\n",
    "image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n",
    "                                          data_transforms[x])\n",
    "                  for x in ['train', 'val']}\n",
    "\n",
    "# load the entire dataset; we are not using minibatches here\n",
    "train_dataset = torch.utils.data.DataLoader(image_datasets['train'],\n",
    "                                            batch_size=len(image_datasets['train']),\n",
    "                                            shuffle=True)\n",
    "\n",
    "test_dataset = torch.utils.data.DataLoader(image_datasets['val'],\n",
    "                                           batch_size=len(image_datasets['val']),\n",
    "                                           shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the LR model\n",
    "class LR(nn.Module):\n",
    "    def __init__(self, dim):\n",
    "        super(LR, self).__init__()\n",
    "        self.linear = nn.Linear(dim, 1)\n",
    "        nn.init.zeros_(self.linear.weight)\n",
    "        nn.init.zeros_(self.linear.bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.linear(x)\n",
    "        x = torch.sigmoid(x)\n",
    "        return x "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict function\n",
    "def predict(yhat, y):\n",
    "    yhat = yhat.squeeze()\n",
    "    y = y.unsqueeze(0) \n",
    "    y_prediction = torch.zeros(y.size()[1])\n",
    "    for i in range(yhat.shape[0]):\n",
    "        if yhat[i] <= 0.5:\n",
    "            y_prediction[i] = 0\n",
    "        else:\n",
    "            y_prediction[i] = 1\n",
    "    return 100 - torch.mean(torch.abs(y_prediction - y)) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model config\n",
    "dim = train_dataset.dataset[0][0].shape[0]\n",
    "\n",
    "lrmodel = LR(dim).to(device)\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = torch.optim.SGD(lrmodel.parameters(), lr=0.0001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cost after iteration 0: 0.6931471228599548 | Train Acc: 50.40983581542969 | Test Acc: 45.75163269042969\n",
      "Cost after iteration 10: 0.6691471338272095 | Train Acc: 64.3442611694336 | Test Acc: 54.24836730957031\n",
      "Cost after iteration 20: 0.6513182520866394 | Train Acc: 68.44261932373047 | Test Acc: 54.24836730957031\n",
      "Cost after iteration 30: 0.6367825269699097 | Train Acc: 68.03278350830078 | Test Acc: 54.24836730957031\n",
      "Cost after iteration 40: 0.6245337128639221 | Train Acc: 69.67213439941406 | Test Acc: 54.90196228027344\n",
      "Cost after iteration 50: 0.6139225363731384 | Train Acc: 70.90164184570312 | Test Acc: 56.20914840698242\n",
      "Cost after iteration 60: 0.6045235395431519 | Train Acc: 72.54098510742188 | Test Acc: 56.86274337768555\n",
      "Cost after iteration 70: 0.5960512161254883 | Train Acc: 74.18032836914062 | Test Acc: 57.51633834838867\n",
      "Cost after iteration 80: 0.5883084535598755 | Train Acc: 73.77049255371094 | Test Acc: 57.51633834838867\n",
      "Cost after iteration 90: 0.5811557769775391 | Train Acc: 74.59016418457031 | Test Acc: 58.1699333190918\n"
     ]
    }
   ],
   "source": [
    "# training the model\n",
    "costs = []\n",
    "\n",
    "for ITER in range(100):\n",
    "    lrmodel.train()\n",
    "    x, y = next(iter(train_dataset))\n",
    "    test_x, test_y = next(iter(test_dataset))\n",
    "\n",
    "    # forward\n",
    "    yhat = lrmodel.forward(x.to(device))\n",
    "\n",
    "    cost = criterion(yhat.squeeze(), y.type(torch.FloatTensor).to(device))\n",
    "    train_pred = predict(yhat, y)\n",
    "\n",
    "    # backward\n",
    "    optimizer.zero_grad()\n",
    "    cost.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    # evaluate\n",
    "    lrmodel.eval()\n",
    "    with torch.no_grad():\n",
    "        yhat_test = lrmodel.forward(test_x.to(device))\n",
    "        test_pred = predict(yhat_test, test_y)\n",
    "\n",
    "    if ITER % 10 == 0:\n",
    "        costs.append(cost)\n",
    "\n",
    "    if ITER % 10 == 0:\n",
    "        print(\"Cost after iteration {}: {} | Train Acc: {} | Test Acc: {}\".format(ITER, \n",
    "                                                                                    cost, \n",
    "                                                                                    train_pred,\n",
    "                                                                                    test_pred))\n",
    "   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### References\n",
    "- [A Logistic Regression Model from Scratch](https://colab.research.google.com/drive/1iBoJ0kngkOthy7SgVaVQA1aHEROt5mra?usp=sharing)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('play')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "cf9800998463bc980d70cdbacff0c7e9a10687346dc898569e92f016d6e252c9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
